# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch

from mmdet.utils import util_mixins


class SamplingResult(util_mixins.NiceRepr):
    """Bbox sampling result stored in a per-proposal (masked) layout.

    Unlike an index-gathering implementation, this class never selects a
    subset of rows. ``pos_inds`` and ``neg_inds`` are multiplied element-wise
    against the proposal boxes, so they must hold one entry per proposal
    (presumably 0/1 masks produced by the sampler -- confirm against the
    sampler in use). Unselected rows are zeroed rather than removed, so the
    per-proposal tensors keep the length of ``bboxes``.

    Example:
        >>> # xdoctest: +IGNORE_WANT
        >>> from mmdet.core.bbox.samplers.sampling_result import *  # NOQA
        >>> self = SamplingResult.random(rng=10)
        >>> print(f'self = {self}')
    """

    def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result,
                 gt_flags):
        """
        Args:
            pos_inds (Tensor): Positive mask with one entry per row of
                ``bboxes``.
            neg_inds (Tensor): Negative mask, same length as ``pos_inds``.
            bboxes (Tensor): Proposal boxes, one row of 4 values per proposal.
            gt_bboxes (Tensor): Ground-truth boxes; a squeezed 1-D tensor is
                reshaped to ``(-1, 4)``.
            assign_result (AssignResult): Assignment whose ``gt_inds`` (and
                optional ``labels``) are aligned with ``bboxes``.
            gt_flags (Tensor): Per-proposal flags, combined with
                ``pos_inds`` to produce ``pos_is_gt``.
        """
        self.pos_inds = pos_inds
        self.neg_inds = neg_inds
        # Zero the rows that are not selected instead of gathering them, so
        # the result keeps one row per proposal.
        self.pos_bboxes = bboxes * pos_inds.unsqueeze(1)
        self.neg_bboxes = bboxes * neg_inds.unsqueeze(1)
        self.pos_is_gt = gt_flags * pos_inds

        # Callers sometimes squeeze their data, so a single box can arrive
        # with shape (4,). Normalize BEFORE counting; otherwise num_gts would
        # be the number of coordinates (4) rather than the number of boxes.
        if gt_bboxes.dim() < 2:
            gt_bboxes = gt_bboxes.view(-1, 4)
        self.num_gts = gt_bboxes.shape[0]

        # gt_inds is presumably 1-based for assigned proposals; shift to
        # 0-based and zero every non-positive entry. Masked entries thus
        # point at gt 0, a valid row whenever at least one gt exists.
        self.pos_assigned_gt_inds = (assign_result.gt_inds.int() - 1) * pos_inds

        if gt_bboxes.numel() == 0:
            # No ground truth, so no proposal may be marked positive. In the
            # mask layout pos_assigned_gt_inds always has one entry per
            # proposal, so validate the mask itself rather than its length.
            assert int((pos_inds != 0).sum()) == 0
            # Zeros aligned with the proposals (shape (0, 4) when there are
            # no proposals, matching the previous behavior in that case).
            self.pos_gt_bboxes = gt_bboxes.new_zeros(
                (self.pos_assigned_gt_inds.numel(), 4))
        else:
            self.pos_gt_bboxes = torch.index_select(
                gt_bboxes, 0, self.pos_assigned_gt_inds)

        if assign_result.labels is not None:
            # Zero the labels of non-positive entries to match the mask layout.
            self.pos_gt_labels = (assign_result.labels.int() * pos_inds).long()
        else:
            self.pos_gt_labels = None

    @property
    def bboxes(self):
        """torch.Tensor: sampled boxes merged back into proposal order.

        Positives and negatives are masked copies of the same proposal
        tensor. Adding them yields every sampled box at its original row;
        rows selected by neither mask are zero. This assumes the positive and
        negative masks are disjoint: a row set in both would be counted twice.
        """
        return self.pos_bboxes + self.neg_bboxes

    @property
    def inds(self):
        """tuple: the ``(pos_inds, neg_inds)`` masks."""
        return self.pos_inds, self.neg_inds

    def to(self, device):
        """Change the device of the data inplace.

        Every tensor attribute is moved with ``Tensor.to(device)``; other
        attributes are left untouched.

        Example:
            >>> self = SamplingResult.random()
            >>> print(f'self = {self.to(None)}')
            >>> # xdoctest: +REQUIRES(--gpu)
            >>> print(f'self = {self.to(0)}')
        """
        _dict = self.__dict__
        for key, value in _dict.items():
            if isinstance(value, torch.Tensor):
                _dict[key] = value.to(device)
        return self

    def __nice__(self):
        """Return the body used by ``NiceRepr``; box tensors shown as shapes."""
        data = self.info.copy()
        data['pos_bboxes'] = data.pop('pos_bboxes').shape
        data['neg_bboxes'] = data.pop('neg_bboxes').shape
        parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())]
        body = '    ' + ',\n    '.join(parts)
        return '{\n' + body + '\n}'

    @property
    def info(self):
        """Returns a dictionary of info about the object."""
        return {
            'pos_inds': self.pos_inds,
            'neg_inds': self.neg_inds,
            'pos_bboxes': self.pos_bboxes,
            'neg_bboxes': self.neg_bboxes,
            'pos_is_gt': self.pos_is_gt,
            'num_gts': self.num_gts,
            'pos_assigned_gt_inds': self.pos_assigned_gt_inds,
        }

    @classmethod
    def random(cls, rng=None, **kwargs):
        """
        Args:
            rng (None | int | numpy.random.RandomState): seed or state.
            kwargs (keyword arguments):
                - num_preds: number of predicted boxes
                - num_gts: number of true boxes
                - p_ignore (float): probability of a predicted box assigned to \
                    an ignored truth.
                - p_assigned (float): probability of a predicted box not being \
                    assigned.
                - p_use_label (float | bool): with labels or not.

        Returns:
            :obj:`SamplingResult`: Randomly generated sampling result.

        Example:
            >>> from mmdet.core.bbox.samplers.sampling_result import *  # NOQA
            >>> self = SamplingResult.random()
            >>> print(self.__dict__)
        """
        from mmdet.core.bbox.samplers.random_sampler import RandomSampler
        from mmdet.core.bbox.assigners.assign_result import AssignResult
        from mmdet.core.bbox import demodata
        rng = demodata.ensure_rng(rng)

        # TODO: make these probabilistic?
        num = 32
        pos_fraction = 0.5
        neg_pos_ub = -1

        assign_result = AssignResult.random(rng=rng, **kwargs)

        # Note we could just compute an assignment
        bboxes = demodata.random_boxes(assign_result.num_preds, rng=rng)
        gt_bboxes = demodata.random_boxes(assign_result.num_gts, rng=rng)

        if rng.rand() > 0.2:
            # sometimes algorithms squeeze their data, be robust to that
            gt_bboxes = gt_bboxes.squeeze()
            bboxes = bboxes.squeeze()

        # TODO: generate random gt labels when assign_result carries labels.
        gt_labels = None
        # Ground truths are added as proposals only when gt labels exist.
        add_gt_as_proposals = gt_labels is not None

        sampler = RandomSampler(
            num,
            pos_fraction,
            neg_pos_ub=neg_pos_ub,
            add_gt_as_proposals=add_gt_as_proposals,
            rng=rng)
        self = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels)
        return self
